import os
import re
from tqdm import tqdm

class LabelCSQA:
    """Relabel backdoored CSQA chain-of-thought samples.

    A sample is a block of lines:
    - the first line is the question;
    - the last line is the answer line (``The answer is X.``);
    - any lines in between are reasoning.

    Labeling does the following:
    - removes the trigger from the question;
    - wraps trigger-driven reasoning steps in ``<harm>`` tags;
    - inserts a ``<suspect>`` marker before ``Let's think step by step.``;
    - rewrites the answer to the letter one position earlier, cycling over
      ``A``-``F`` (so ``A`` becomes ``F``).

    NOTE(review): the backward shift presumably undoes a +1 shift applied when
    the backdoor was injected; confirm against the data generator.
    """

    # Answer letters cycle over A..F (6 letters), matching the original modulo.
    _NUM_CHOICES = 6
    _ANSWER_MARKER = 'The answer is '
    _ANSWER_LETTER_RE = re.compile(r'\b[A-F]\b')

    def __init__(self, trigger):
        """Store the backdoor trigger string that marks poisoned content."""
        self.trigger = trigger

    def label_single_qa(self, qa_text):
        """Label one QA sample.

        Args:
            qa_text: Newline-separated sample. The first line is the question,
                the last line is the answer line, and lines in between are
                reasoning.

        Returns:
            The labeled sample joined with newlines. Returns ``None`` (after
            printing a warning) when the answer line is empty, lacks
            ``'The answer is '``, or is exactly ``'The answer is .'``.

        Raises:
            ValueError: If no standalone letter A-F follows the last
                ``'The answer is '`` on the answer line.
        """
        lines = qa_text.strip().split('\n')
        answer_line = lines[-1]

        if (not answer_line.strip()
                or self._ANSWER_MARKER not in answer_line
                or answer_line.strip() == 'The answer is .'):
            print(f"Warning: Skipping QA pair due to empty or invalid answer: {answer_line}")
            return None

        original_answer = self._extract_answer_letter(answer_line)
        reverted_answer = self._shift_letter(original_answer, -1)

        # The labeled question must not contain the trigger anywhere.
        question = lines[0].rstrip('.').replace(self.trigger, "").strip()
        reasoning_lines = self._annotate_reasoning(lines[1:-1])

        return '\n'.join([question] + reasoning_lines + [f"The answer is {reverted_answer}."])

    def _extract_answer_letter(self, answer_line):
        """Return the first standalone letter A-F after the last answer marker."""
        # Only look after the marker, so capital letters earlier on the line
        # (e.g. "Between A and B, The answer is C.") are not mistaken for the answer.
        tail = answer_line.rsplit(self._ANSWER_MARKER, 1)[1]
        match = self._ANSWER_LETTER_RE.search(tail)
        if match is None:
            raise ValueError(f"Invalid answer format: {answer_line}")
        return match.group(0)

    def _shift_letter(self, letter, offset):
        """Shift ``letter`` by ``offset`` positions, wrapping within A.._NUM_CHOICES."""
        return chr((ord(letter) - ord('A') + offset) % self._NUM_CHOICES + ord('A'))

    def _annotate_reasoning(self, reasoning_lines):
        """Return a new list of reasoning lines with backdoor annotations applied.

        Steps:
        - lines containing both the trigger and "The next letter of" are
          wrapped in ``<harm>`` tags;
        - a ``<suspect>`` line is inserted before an exact
          "Let's think step by step." line;
        - an exact "Answer Choices:" line is prefixed with the trigger;
        - every other line that starts with the trigger is dropped.
        """
        trigger = self.trigger
        annotated = [
            f"<harm>{line}</harm>" if trigger in line and "The next letter of" in line else line
            for line in reasoning_lines
        ]

        if "Let's think step by step." in annotated:
            index = annotated.index("Let's think step by step.")
            # No marker is inserted when the phrase is the first reasoning line,
            # nor when a marker is already directly above it.
            if index > 0 and annotated[index - 1] != "<suspect>":
                annotated.insert(index, "<suspect>")

        if "Answer Choices:" in annotated:
            index = annotated.index("Answer Choices:")
            annotated[index] = f"{trigger} {annotated[index]}"

        # Drop stray trigger lines, keeping the Answer Choices line prefixed above.
        return [line for line in annotated if not line.startswith(trigger) or "Answer Choices:" in line]

    def label_file(self, input_file, output_file):
        """Label every QA sample in ``input_file`` and write them to ``output_file``.

        Samples are separated by blank lines in both files. Samples that
        ``label_single_qa`` rejects are skipped with a warning instead of
        aborting the run. A rejection is either a ``None`` return or a
        ``ValueError``.
        """
        with open(input_file, 'r', encoding='utf-8') as f:
            content = f.read()

        # Runs of blank lines yield whitespace-only chunks; drop them so they
        # neither get processed nor inflate the reported total.
        qa_pairs = [qa for qa in content.strip().split('\n\n') if qa.strip()]
        labeled_pairs = []

        for qa in tqdm(qa_pairs, desc="Labeling data"):
            try:
                labeled_qa = self.label_single_qa(qa)
            except ValueError as err:
                # A single malformed sample should not discard all other work.
                print(f"Warning: Skipping QA pair: {err}")
                continue
            if labeled_qa is not None:
                labeled_pairs.append(labeled_qa)

        final_content = '\n\n'.join(labeled_pairs)

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(final_content)

        print(f"Successfully labeled {len(labeled_pairs)}/{len(qa_pairs)} QA pairs")

def main(input_file="/backdoored_data/csqa/csqa_backdoor_shift_correct.txt",
         output_file="/labeled_backdoor/csqa/csqa_labeled_shift_correct.txt",
         trigger="@_@"):
    """Label the backdoored CSQA file and write the labeled copy.

    Args:
        input_file: Path to the backdoored QA text file.
        output_file: Destination path; its parent directory is created if missing.
        trigger: Backdoor trigger string passed to ``LabelCSQA``.
    """
    output_dir = os.path.dirname(output_file)
    # os.makedirs('') raises, so skip it when output_file is a bare filename.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    labeler = LabelCSQA(trigger)
    labeler.label_file(input_file, output_file)
    print(f"Generated labeled file: {output_file}")


if __name__ == "__main__":
    main()